{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)\n",
    "\n",
    "![title](ctc.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "pip3 install librosa\n",
    "pip3 install bs4\n",
    "\"\"\"\n",
    "\n",
    "from urllib.request import urlopen, urlretrieve\n",
    "from bs4 import BeautifulSoup\n",
    "from tqdm import tqdm\n",
    "\n",
    "import re\n",
    "import os\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import librosa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "PARAMS = {\n",
    "    'num_epochs': 100,\n",
    "    'batch_size': 30,\n",
    "    'clip_norm': 5.0,\n",
    "    'winstep': 0.01,\n",
    "    'n_mfcc': 39,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def download():\n",
    "    prefix = 'https://tspace.library.utoronto.ca'\n",
    "    save_dir = './data/'\n",
    "    if not os.path.exists(save_dir):\n",
    "        os.makedirs(save_dir)\n",
    "\n",
    "    base_url = 'https://tspace.library.utoronto.ca/handle/1807/24'\n",
    "    urls = [base_url+str(i) for i in range(488, 502)]\n",
    "\n",
    "    for url in urls:\n",
    "        soup = BeautifulSoup(urlopen(url).read(), 'html5lib')\n",
    "        targets = soup.findAll('a', href=re.compile(r'/bitstream/.*.wav'))\n",
    "        \n",
    "        for a in tqdm(targets, total=len(targets), ncols=70):\n",
    "            link = a['href']\n",
    "\n",
    "            audio_save_loc = save_dir + link.split('/')[-1]\n",
    "            if os.path.isfile(audio_save_loc):\n",
    "                print(\"File Already Exists\")\n",
    "            urlretrieve(prefix+a['href'], audio_save_loc)\n",
    "\n",
    "            with open(audio_save_loc.replace('.wav', '.txt'), 'w') as f:\n",
    "                f.write('say the word ' + link.split('_')[-2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sparse_tuple_from(sequences, dtype=np.int32):\n",
    "    \"\"\"Create a sparse representention of x.\n",
    "    Args:\n",
    "        sequences: a list of lists of type dtype where each element is a sequence\n",
    "    Returns:\n",
    "        A tuple with (indices, values, shape)\n",
    "    \"\"\"\n",
    "    indices = []\n",
    "    values = []\n",
    "\n",
    "    for n, seq in enumerate(sequences):\n",
    "        indices.extend(zip([n]*len(seq), range(len(seq))))\n",
    "        values.extend(seq)\n",
    "\n",
    "    indices = np.asarray(indices, dtype=np.int64)\n",
    "    values = np.asarray(values, dtype=dtype)\n",
    "    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1]+1], dtype=np.int64)\n",
    "\n",
    "    return (indices, values, shape)\n",
    "\n",
    "\n",
    "def train_input_fn(X, y):\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((X, y))\n",
    "    dataset = dataset.shuffle(10000).batch(PARAMS['batch_size']).repeat(PARAMS['num_epochs'])\n",
    "    iterator = dataset.make_one_shot_iterator()\n",
    "    return iterator.get_next()\n",
    "\n",
    "\n",
    "def rnn_cell():\n",
    "    return tf.nn.rnn_cell.GRUCell(PARAMS['n_mfcc'],\n",
    "                                  kernel_initializer=tf.orthogonal_initializer)\n",
    "\n",
    "\n",
    "def clip_grads(loss_op):\n",
    "    variables = tf.trainable_variables()\n",
    "    grads = tf.gradients(loss_op, variables)\n",
    "    clipped_grads, _ = tf.clip_by_global_norm(grads, PARAMS['clip_norm'])\n",
    "    return zip(clipped_grads, variables)\n",
    "\n",
    "\n",
    "def model_fn(features, labels, mode, params):\n",
    "    seq_lens = tf.count_nonzero(tf.reduce_sum(features, -1), 1, dtype=tf.int32)\n",
    "    \n",
    "    outputs, _ = tf.nn.bidirectional_dynamic_rnn(cell_fw = rnn_cell(),\n",
    "                                                 cell_bw = rnn_cell(),\n",
    "                                                 inputs = features,\n",
    "                                                 sequence_length = seq_lens,\n",
    "                                                 dtype = tf.float32,)\n",
    "    outputs = tf.concat(outputs, -1)\n",
    "    logits = tf.layers.dense(outputs, PARAMS['num_classes'])\n",
    "    \n",
    "    time_major = tf.transpose(logits, [1,0,2])\n",
    "    decoded, log_prob = tf.nn.ctc_greedy_decoder(time_major, seq_lens)\n",
    "    decoded = tf.to_int32(decoded[0])\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.PREDICT:\n",
    "        preds = tf.sparse_tensor_to_dense(decoded)\n",
    "        return tf.estimator.EstimatorSpec(mode, predictions=preds)\n",
    "    \n",
    "    if mode == tf.estimator.ModeKeys.TRAIN:\n",
    "        loss_op = tf.reduce_mean(tf.nn.ctc_loss(labels, time_major, seq_lens))\n",
    "        edit_dist_op = tf.reduce_mean(tf.edit_distance(decoded, labels))\n",
    "\n",
    "        lth = tf.train.LoggingTensorHook({'edit_dist': edit_dist_op}, every_n_iter=100)\n",
    "        \n",
    "        train_op = tf.train.AdamOptimizer().apply_gradients(clip_grads(loss_op),\n",
    "                                                            global_step=tf.train.get_global_step())\n",
    "        \n",
    "        return tf.estimator.EstimatorSpec(\n",
    "            mode=mode, loss=loss_op, train_op=train_op, training_hooks=[lth])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████| 200/200 [03:03<00:00,  1.09it/s]\n",
      "100%|███████████████████████████████| 200/200 [02:47<00:00,  1.19it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:04<00:00,  1.08it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:02<00:00,  1.10it/s]\n",
      "100%|███████████████████████████████| 200/200 [02:43<00:00,  1.22it/s]\n",
      "100%|███████████████████████████████| 200/200 [02:58<00:00,  1.12it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:19<00:00,  1.00it/s]\n",
      "100%|███████████████████████████████| 200/200 [02:59<00:00,  1.11it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:04<00:00,  1.09it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:22<00:00,  1.01s/it]\n",
      "100%|███████████████████████████████| 200/200 [03:23<00:00,  1.02s/it]\n",
      "100%|███████████████████████████████| 200/200 [02:41<00:00,  1.24it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:12<00:00,  1.04it/s]\n",
      "100%|███████████████████████████████| 200/200 [03:01<00:00,  1.10it/s]\n"
     ]
    }
   ],
   "source": [
    "download()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████| 2800/2800 [00:31<00:00, 89.66it/s]\n"
     ]
    }
   ],
   "source": [
    "wav_files = [f for f in os.listdir('./data') if f.endswith('.wav')]\n",
    "text_files = [f for f in os.listdir('./data') if f.endswith('.txt')]\n",
    "\n",
    "inputs, targets = [], []\n",
    "for (wav_file, text_file) in tqdm(zip(wav_files, text_files), total=len(wav_files), ncols=70):\n",
    "    path = './data/' + wav_file\n",
    "    try:\n",
    "        y, sr = librosa.load(path, sr=None)\n",
    "    except:\n",
    "        continue\n",
    "    inputs.append(librosa.feature.mfcc(y = y,\n",
    "                                       sr = sr,\n",
    "                                       n_mfcc = PARAMS['n_mfcc'],\n",
    "                                       hop_length = int(PARAMS['winstep']*sr)).T)\n",
    "    with open('./data/'+text_file) as f:\n",
    "        targets.append(f.read())\n",
    "\n",
    "inputs = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "    inputs, dtype='float32', padding='post')\n",
    "\n",
    "chars = list(set([c for target in targets for c in target]))\n",
    "PARAMS['num_classes'] = len(chars) + 1\n",
    "\n",
    "idx2char = {idx: char for idx, char in enumerate(chars)}\n",
    "char2idx = {char: idx for idx, char in idx2char.items()}\n",
    "\n",
    "targets = [[char2idx[c] for c in target] for target in targets]\n",
    "\n",
    "inputs_val = np.expand_dims(inputs[-1], 0)\n",
    "targets_val = targets[-1]\n",
    "\n",
    "inputs_train = inputs[:-1]\n",
    "targets_train = targets[:-1]\n",
    "targets_train = tf.SparseTensor(*sparse_tuple_from(targets_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using default config.\n",
      "WARNING:tensorflow:Using temporary folder as model directory: /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6\n",
      "INFO:tensorflow:Using config: {'_model_dir': '/var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x1196a5898>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
      "WARNING:tensorflow:Estimator's model_fn (<function model_fn at 0x117795598>) includes params argument, but params are not passed to Estimator.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Create CheckpointSaverHook.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "INFO:tensorflow:Saving checkpoints for 1 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6/model.ckpt.\n",
      "INFO:tensorflow:loss = 374.96756, step = 1\n",
      "INFO:tensorflow:edit_dist = 1.8703568\n",
      "INFO:tensorflow:global_step/sec: 4.23402\n",
      "INFO:tensorflow:loss = 59.0886, step = 101 (23.620 sec)\n",
      "INFO:tensorflow:edit_dist = 0.93818295 (23.620 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.88827\n",
      "INFO:tensorflow:loss = 50.93571, step = 201 (20.457 sec)\n",
      "INFO:tensorflow:edit_dist = 0.91707516 (20.456 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.80715\n",
      "INFO:tensorflow:loss = 45.361046, step = 301 (20.803 sec)\n",
      "INFO:tensorflow:edit_dist = 0.82914376 (20.802 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.72579\n",
      "INFO:tensorflow:loss = 42.25828, step = 401 (21.161 sec)\n",
      "INFO:tensorflow:edit_dist = 0.7802259 (21.162 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.69682\n",
      "INFO:tensorflow:loss = 39.114063, step = 501 (21.291 sec)\n",
      "INFO:tensorflow:edit_dist = 0.72288674 (21.291 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.7023\n",
      "INFO:tensorflow:loss = 36.46131, step = 601 (21.266 sec)\n",
      "INFO:tensorflow:edit_dist = 0.68888825 (21.266 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.7319\n",
      "INFO:tensorflow:loss = 34.118046, step = 701 (21.133 sec)\n",
      "INFO:tensorflow:edit_dist = 0.6578432 (21.133 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.77862\n",
      "INFO:tensorflow:loss = 35.443798, step = 801 (20.926 sec)\n",
      "INFO:tensorflow:edit_dist = 0.60701257 (20.926 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.69831\n",
      "INFO:tensorflow:loss = 33.633713, step = 901 (21.284 sec)\n",
      "INFO:tensorflow:edit_dist = 0.57852817 (21.284 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.71906\n",
      "INFO:tensorflow:loss = 30.519993, step = 1001 (21.190 sec)\n",
      "INFO:tensorflow:edit_dist = 0.5832617 (21.191 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.69079\n",
      "INFO:tensorflow:loss = 32.199757, step = 1101 (21.319 sec)\n",
      "INFO:tensorflow:edit_dist = 0.54267716 (21.318 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.19085\n",
      "INFO:tensorflow:loss = 29.603216, step = 1201 (23.862 sec)\n",
      "INFO:tensorflow:edit_dist = 0.5057326 (23.862 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.35138\n",
      "INFO:tensorflow:loss = 29.17682, step = 1301 (22.981 sec)\n",
      "INFO:tensorflow:edit_dist = 0.5305828 (22.982 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.236\n",
      "INFO:tensorflow:loss = 28.77946, step = 1401 (23.608 sec)\n",
      "INFO:tensorflow:edit_dist = 0.5081019 (23.608 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.41452\n",
      "INFO:tensorflow:loss = 26.51593, step = 1501 (22.652 sec)\n",
      "INFO:tensorflow:edit_dist = 0.4651201 (22.651 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.02982\n",
      "INFO:tensorflow:loss = 28.168192, step = 1601 (24.814 sec)\n",
      "INFO:tensorflow:edit_dist = 0.4389706 (24.814 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.05223\n",
      "INFO:tensorflow:loss = 26.571934, step = 1701 (24.678 sec)\n",
      "INFO:tensorflow:edit_dist = 0.44563484 (24.677 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.16043\n",
      "INFO:tensorflow:loss = 24.21102, step = 1801 (24.036 sec)\n",
      "INFO:tensorflow:edit_dist = 0.39577886 (24.036 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.07595\n",
      "INFO:tensorflow:loss = 24.867167, step = 1901 (24.535 sec)\n",
      "INFO:tensorflow:edit_dist = 0.41833147 (24.535 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.17109\n",
      "INFO:tensorflow:loss = 23.64304, step = 2001 (23.974 sec)\n",
      "INFO:tensorflow:edit_dist = 0.37883988 (23.974 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.13325\n",
      "INFO:tensorflow:loss = 22.543957, step = 2101 (24.197 sec)\n",
      "INFO:tensorflow:edit_dist = 0.37087423 (24.197 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.30542\n",
      "INFO:tensorflow:loss = 22.561909, step = 2201 (23.223 sec)\n",
      "INFO:tensorflow:edit_dist = 0.35443974 (23.223 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.73287\n",
      "INFO:tensorflow:loss = 24.117044, step = 2301 (21.129 sec)\n",
      "INFO:tensorflow:edit_dist = 0.36071625 (21.130 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.95563\n",
      "INFO:tensorflow:loss = 23.403711, step = 2401 (20.179 sec)\n",
      "INFO:tensorflow:edit_dist = 0.37744528 (20.179 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.69119\n",
      "INFO:tensorflow:loss = 21.144855, step = 2501 (21.319 sec)\n",
      "INFO:tensorflow:edit_dist = 0.35170206 (21.319 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.12461\n",
      "INFO:tensorflow:loss = 22.33146, step = 2601 (24.242 sec)\n",
      "INFO:tensorflow:edit_dist = 0.33824176 (24.242 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 2665 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6/model.ckpt.\n",
      "INFO:tensorflow:global_step/sec: 4.11759\n",
      "INFO:tensorflow:loss = 22.036669, step = 2701 (24.287 sec)\n",
      "INFO:tensorflow:edit_dist = 0.32214054 (24.287 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.46808\n",
      "INFO:tensorflow:loss = 20.996695, step = 2801 (18.287 sec)\n",
      "INFO:tensorflow:edit_dist = 0.329262 (18.288 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.91364\n",
      "INFO:tensorflow:loss = 23.102398, step = 2901 (16.910 sec)\n",
      "INFO:tensorflow:edit_dist = 0.34477252 (16.909 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.80056\n",
      "INFO:tensorflow:loss = 20.694818, step = 3001 (17.240 sec)\n",
      "INFO:tensorflow:edit_dist = 0.31150728 (17.240 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.6758\n",
      "INFO:tensorflow:loss = 20.512148, step = 3101 (17.619 sec)\n",
      "INFO:tensorflow:edit_dist = 0.31199473 (17.619 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.89227\n",
      "INFO:tensorflow:loss = 19.161385, step = 3201 (16.973 sec)\n",
      "INFO:tensorflow:edit_dist = 0.3116013 (16.973 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.63655\n",
      "INFO:tensorflow:loss = 19.659054, step = 3301 (17.740 sec)\n",
      "INFO:tensorflow:edit_dist = 0.30624455 (17.740 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.80568\n",
      "INFO:tensorflow:loss = 18.161318, step = 3401 (17.224 sec)\n",
      "INFO:tensorflow:edit_dist = 0.29782137 (17.224 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.9372\n",
      "INFO:tensorflow:loss = 17.332567, step = 3501 (16.843 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2920071 (16.843 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.1187\n",
      "INFO:tensorflow:loss = 18.79929, step = 3601 (16.343 sec)\n",
      "INFO:tensorflow:edit_dist = 0.31719628 (16.343 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.83637\n",
      "INFO:tensorflow:loss = 18.892756, step = 3701 (17.134 sec)\n",
      "INFO:tensorflow:edit_dist = 0.29947785 (17.134 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.08263\n",
      "INFO:tensorflow:loss = 18.513826, step = 3801 (16.440 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2865087 (16.440 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.00706\n",
      "INFO:tensorflow:loss = 18.958822, step = 3901 (16.647 sec)\n",
      "INFO:tensorflow:edit_dist = 0.28152233 (16.647 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.87627\n",
      "INFO:tensorflow:loss = 17.848106, step = 4001 (17.018 sec)\n",
      "INFO:tensorflow:edit_dist = 0.3113426 (17.018 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.83397\n",
      "INFO:tensorflow:loss = 19.42059, step = 4101 (17.141 sec)\n",
      "INFO:tensorflow:edit_dist = 0.29863837 (17.141 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.0269\n",
      "INFO:tensorflow:loss = 21.195824, step = 4201 (16.592 sec)\n",
      "INFO:tensorflow:edit_dist = 0.32390052 (16.592 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.80332\n",
      "INFO:tensorflow:loss = 17.366398, step = 4301 (20.820 sec)\n",
      "INFO:tensorflow:edit_dist = 0.27848583 (20.821 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.00877\n",
      "INFO:tensorflow:loss = 17.95521, step = 4401 (24.945 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2824619 (24.945 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.4139\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:loss = 16.631626, step = 4501 (22.655 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2836084 (22.656 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.25102\n",
      "INFO:tensorflow:loss = 17.992622, step = 4601 (23.524 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2947876 (23.524 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.08951\n",
      "INFO:tensorflow:loss = 17.18179, step = 4701 (24.452 sec)\n",
      "INFO:tensorflow:edit_dist = 0.27750546 (24.452 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.9098\n",
      "INFO:tensorflow:loss = 17.264767, step = 4801 (25.579 sec)\n",
      "INFO:tensorflow:edit_dist = 0.29173332 (25.580 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.06912\n",
      "INFO:tensorflow:loss = 15.911165, step = 4901 (24.574 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26921427 (24.574 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.91084\n",
      "INFO:tensorflow:loss = 16.283234, step = 5001 (25.569 sec)\n",
      "INFO:tensorflow:edit_dist = 0.27881262 (25.569 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.89122\n",
      "INFO:tensorflow:loss = 41.074135, step = 5101 (25.699 sec)\n",
      "INFO:tensorflow:edit_dist = 0.49321896 (25.699 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.86929\n",
      "INFO:tensorflow:loss = 29.585585, step = 5201 (25.844 sec)\n",
      "INFO:tensorflow:edit_dist = 0.38033772 (25.844 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.91547\n",
      "INFO:tensorflow:loss = 23.962727, step = 5301 (25.540 sec)\n",
      "INFO:tensorflow:edit_dist = 0.36094627 (25.540 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.85529\n",
      "INFO:tensorflow:loss = 22.436602, step = 5401 (25.938 sec)\n",
      "INFO:tensorflow:edit_dist = 0.33216232 (25.938 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.99989\n",
      "INFO:tensorflow:loss = 20.587572, step = 5501 (25.001 sec)\n",
      "INFO:tensorflow:edit_dist = 0.3263671 (25.001 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 5556 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6/model.ckpt.\n",
      "INFO:tensorflow:global_step/sec: 3.80259\n",
      "INFO:tensorflow:loss = 17.250906, step = 5601 (26.298 sec)\n",
      "INFO:tensorflow:edit_dist = 0.27117378 (26.299 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.31421\n",
      "INFO:tensorflow:loss = 18.99039, step = 5701 (23.181 sec)\n",
      "INFO:tensorflow:edit_dist = 0.30856752 (23.181 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.2908\n",
      "INFO:tensorflow:loss = 15.75772, step = 5801 (23.303 sec)\n",
      "INFO:tensorflow:edit_dist = 0.27979305 (23.303 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.45054\n",
      "INFO:tensorflow:loss = 16.252146, step = 5901 (22.469 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26365742 (22.469 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.18691\n",
      "INFO:tensorflow:loss = 16.698246, step = 6001 (23.885 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26505634 (23.886 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.93354\n",
      "INFO:tensorflow:loss = 16.946568, step = 6101 (25.422 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2820779 (25.422 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.97578\n",
      "INFO:tensorflow:loss = 16.34867, step = 6201 (25.152 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26811287 (25.150 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.06189\n",
      "INFO:tensorflow:loss = 15.72277, step = 6301 (24.620 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24580683 (24.620 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.10268\n",
      "INFO:tensorflow:loss = 15.212906, step = 6401 (24.374 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2581456 (24.374 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.6194\n",
      "INFO:tensorflow:loss = 14.954583, step = 6501 (27.629 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24238834 (27.629 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.80485\n",
      "INFO:tensorflow:loss = 14.875512, step = 6601 (26.282 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24487744 (26.282 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.05539\n",
      "INFO:tensorflow:loss = 15.2634735, step = 6701 (24.659 sec)\n",
      "INFO:tensorflow:edit_dist = 0.22339396 (24.659 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.92058\n",
      "INFO:tensorflow:loss = 16.228151, step = 6801 (25.507 sec)\n",
      "INFO:tensorflow:edit_dist = 0.25901487 (25.506 sec)\n",
      "INFO:tensorflow:global_step/sec: 3.91818\n",
      "INFO:tensorflow:loss = 18.399601, step = 6901 (25.521 sec)\n",
      "INFO:tensorflow:edit_dist = 0.21872278 (25.521 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.05256\n",
      "INFO:tensorflow:loss = 15.993952, step = 7001 (24.676 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24000403 (24.676 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.45084\n",
      "INFO:tensorflow:loss = 14.157624, step = 7101 (22.468 sec)\n",
      "INFO:tensorflow:edit_dist = 0.21493807 (22.468 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.29698\n",
      "INFO:tensorflow:loss = 18.547455, step = 7201 (23.272 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26891413 (23.272 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.42466\n",
      "INFO:tensorflow:loss = 14.631651, step = 7301 (22.602 sec)\n",
      "INFO:tensorflow:edit_dist = 0.23264678 (22.604 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.46825\n",
      "INFO:tensorflow:loss = 17.281555, step = 7401 (22.378 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2436554 (22.376 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.87834\n",
      "INFO:tensorflow:loss = 14.201106, step = 7501 (17.011 sec)\n",
      "INFO:tensorflow:edit_dist = 0.21492375 (17.012 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.28769\n",
      "INFO:tensorflow:loss = 13.416147, step = 7601 (15.904 sec)\n",
      "INFO:tensorflow:edit_dist = 0.22045279 (15.904 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.25889\n",
      "INFO:tensorflow:loss = 13.725887, step = 7701 (15.977 sec)\n",
      "INFO:tensorflow:edit_dist = 0.22992562 (15.977 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.25028\n",
      "INFO:tensorflow:loss = 14.554529, step = 7801 (15.999 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2360595 (15.999 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.17141\n",
      "INFO:tensorflow:loss = 14.297422, step = 7901 (16.205 sec)\n",
      "INFO:tensorflow:edit_dist = 0.23541954 (16.205 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.20114\n",
      "INFO:tensorflow:loss = 13.266538, step = 8001 (16.125 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2166674 (16.125 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.18092\n",
      "INFO:tensorflow:loss = 24.518171, step = 8101 (16.179 sec)\n",
      "INFO:tensorflow:edit_dist = 0.28858146 (16.179 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.19879\n",
      "INFO:tensorflow:loss = 18.830341, step = 8201 (16.132 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2930147 (16.132 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.28011\n",
      "INFO:tensorflow:loss = 16.537518, step = 8301 (15.923 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24828431 (15.923 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 8320 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6/model.ckpt.\n",
      "INFO:tensorflow:global_step/sec: 5.34741\n",
      "INFO:tensorflow:loss = 17.05547, step = 8401 (18.701 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26148695 (18.701 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.23188\n",
      "INFO:tensorflow:loss = 16.175426, step = 8501 (16.047 sec)\n",
      "INFO:tensorflow:edit_dist = 0.25617474 (16.047 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.19495\n",
      "INFO:tensorflow:loss = 16.301945, step = 8601 (16.142 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24738564 (16.142 sec)\n",
      "INFO:tensorflow:global_step/sec: 5.87057\n",
      "INFO:tensorflow:loss = 16.246786, step = 8701 (17.035 sec)\n",
      "INFO:tensorflow:edit_dist = 0.24595588 (17.035 sec)\n",
      "INFO:tensorflow:global_step/sec: 6.21705\n",
      "INFO:tensorflow:loss = 18.202633, step = 8801 (16.084 sec)\n",
      "INFO:tensorflow:edit_dist = 0.26606756 (16.084 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.9914\n",
      "INFO:tensorflow:loss = 17.892023, step = 8901 (20.035 sec)\n",
      "INFO:tensorflow:edit_dist = 0.2648693 (20.035 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.30781\n",
      "INFO:tensorflow:loss = 16.812265, step = 9001 (23.214 sec)\n",
      "INFO:tensorflow:edit_dist = 0.25671154 (23.215 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.1313\n",
      "INFO:tensorflow:loss = 15.238453, step = 9101 (24.205 sec)\n",
      "INFO:tensorflow:edit_dist = 0.23532136 (24.204 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.4606\n",
      "INFO:tensorflow:loss = 15.169007, step = 9201 (22.419 sec)\n",
      "INFO:tensorflow:edit_dist = 0.22878686 (22.418 sec)\n",
      "INFO:tensorflow:global_step/sec: 4.46589\n",
      "INFO:tensorflow:loss = 14.968628, step = 9301 (22.393 sec)\n",
      "INFO:tensorflow:edit_dist = 0.23009259 (22.393 sec)\n",
      "INFO:tensorflow:Saving checkpoints for 9400 into /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6/model.ckpt.\n",
      "INFO:tensorflow:Loss for final step: 13.807625.\n",
      "INFO:tensorflow:Calling model_fn.\n",
      "INFO:tensorflow:Done calling model_fn.\n",
      "INFO:tensorflow:Graph was finalized.\n",
      "INFO:tensorflow:Restoring parameters from /var/folders/sx/fv0r97j96fz8njp14dt5g7940000gn/T/tmpcmpelip6/model.ckpt-9400\n",
      "INFO:tensorflow:Running local_init_op.\n",
      "INFO:tensorflow:Done running local_init_op.\n",
      "\n",
      "Ground Truth: say the word youth\n",
      "\n",
      "Prediction: say the word e\n"
     ]
    }
   ],
   "source": [
    "estimator = tf.estimator.Estimator(model_fn)\n",
    "\n",
    "estimator.train(lambda: train_input_fn(inputs_train, targets_train))\n",
    "\n",
    "preds = list(estimator.predict(tf.estimator.inputs.numpy_input_fn(inputs_val, shuffle=False)))\n",
    "\n",
    "print()\n",
    "print('Ground Truth:', ''.join([idx2char[idx] for idx in targets_val]))\n",
    "print()\n",
    "print('Prediction:', ''.join([idx2char[idx] for idx in preds[0]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
